import abc
import scalevi.distributions.distributions_branched_base as dists_branched_base


class VarFam(abc.ABC):
    """Abstract parameter-handling interface (presumably for a variational family).

    Concrete subclasses must provide both ``get_params`` and
    ``initial_params``; until they do, ``abc`` refuses to instantiate them.
    """

    @abc.abstractmethod
    def get_params(self, params):
        """Return the parameters derived from ``params`` (subclass-defined)."""

    @abc.abstractmethod
    def initial_params(self):
        """Return an initial parameter value (subclass-defined)."""

class VarDist(VarFam):
    """Abstract distribution with separate ``sample`` and ``log_prob`` methods.

    Args:
        z_dim: Stored verbatim on the instance as ``self.z_dim``
            (presumably the dimensionality of the latent variable -- confirm).
    """

    def __init__(self, z_dim):
        self.z_dim = z_dim

    @abc.abstractmethod
    def sample(self, rng_key, params, chunk):
        """Draw a sample using ``rng_key`` and ``params`` for ``chunk``.

        NOTE(review): ``rng_key`` is presumably a PRNG key and ``chunk`` a
        data-subset selector -- neither is established here; confirm.
        """

    @abc.abstractmethod
    def log_prob(self, z, params, chunk):
        """Evaluate the log-density of ``z`` under ``params`` for ``chunk``."""

class VarDistWithSampleEval(VarFam):
    """Abstract distribution whose sampling and log-density come from one call.

    Args:
        z_dim: Stored verbatim on the instance as ``self.z_dim``
            (presumably the latent dimensionality -- confirm).
    """

    def __init__(self, z_dim):
        self.z_dim = z_dim

    @abc.abstractmethod
    def sample_and_log_prob(self, rng_key, params, chunk):
        """Return a sample together with its log-probability (subclass-defined).

        NOTE(review): the exact return structure is decided by subclasses;
        a ``(sample, log_prob)`` pair is suggested by the name -- confirm.
        """


class VarBranchDist(dists_branched_base.SimpleBranchDist, VarFam):
    """Branched distribution whose parameters split into parent and child parts.

    Combines ``SimpleBranchDist`` (from ``distributions_branched_base``) with
    the ``VarFam`` interface. Subclasses supply ``get_params_parent`` and
    ``get_params_child``; ``get_params`` selects between them.

    Args:
        N_chunk: Forwarded unchanged to ``SimpleBranchDist.__init__``
            (presumably the number of chunks -- confirm in the base class).
        z_dim: Stored as ``self.z_dim``.
    """

    def __init__(self, N_chunk, z_dim):
        # z_dim is assigned before the base initializer runs, in case the
        # base class reads it during construction.
        self.z_dim = z_dim
        super().__init__(N_chunk)

    @abc.abstractmethod
    def get_params_parent(self, params):
        """Return the parent-level parameters extracted from ``params``."""

    @abc.abstractmethod
    def get_params_child(self, params, θ, chunk):
        """Return the child-level parameters given parent sample ``θ`` and ``chunk``."""

    def get_params(self, params, which_param, θ, chunk):
        """Dispatch to the parent or child parameter getter.

        NOTE(review): this signature is wider than ``VarFam.get_params``
        (which takes only ``params``).

        Args:
            params: Raw parameter container, passed through unchanged.
            which_param: Either ``"parent"`` or ``"child"``.
            θ: Forwarded to ``get_params_child`` only; ignored for ``"parent"``.
            chunk: Forwarded to ``get_params_child`` only; ignored for ``"parent"``.

        Returns:
            Whatever the selected getter returns.

        Raises:
            ValueError: If ``which_param`` is neither ``"parent"`` nor ``"child"``.
        """
        if which_param == "parent":
            return self.get_params_parent(params)
        if which_param == "child":
            return self.get_params_child(params, θ, chunk)
        raise ValueError(
            "Expected which_param to be either "
            f"'parent' or 'child' but got {which_param}")

class VarBranchDistWithSampleEval(dists_branched_base.SimpleBranchDistWithSampleEval, VarFam):
    """Branched sample-and-evaluate distribution with parent/child parameters.

    Combines ``SimpleBranchDistWithSampleEval`` (from
    ``distributions_branched_base``) with the ``VarFam`` interface. Any extra
    keyword arguments given to ``get_params`` are forwarded unchanged to
    whichever getter is selected.

    Args:
        N_chunk: Forwarded unchanged to the base class initializer
            (presumably the number of chunks -- confirm in the base class).
        z_dim: Stored as ``self.z_dim``.
    """

    def __init__(self, N_chunk, z_dim):
        # z_dim is assigned before the base initializer runs, in case the
        # base class reads it during construction.
        self.z_dim = z_dim
        super().__init__(N_chunk)

    @abc.abstractmethod
    def get_params_parent(self, params, **kwargs):
        """Return the parent-level parameters extracted from ``params``."""

    @abc.abstractmethod
    def get_params_child(self, params, θ, chunk, **kwargs):
        """Return the child-level parameters given parent sample ``θ`` and ``chunk``."""

    def get_params(self, params, which_param, θ, chunk, **kwargs):
        """Dispatch to the parent or child parameter getter, forwarding ``kwargs``.

        NOTE(review): this signature is wider than ``VarFam.get_params``
        (which takes only ``params``).

        Args:
            params: Raw parameter container, passed through unchanged.
            which_param: Either ``"parent"`` or ``"child"``.
            θ: Forwarded to ``get_params_child`` only; ignored for ``"parent"``.
            chunk: Forwarded to ``get_params_child`` only; ignored for ``"parent"``.
            **kwargs: Forwarded to the selected getter.

        Returns:
            Whatever the selected getter returns.

        Raises:
            ValueError: If ``which_param`` is neither ``"parent"`` nor ``"child"``.
        """
        if which_param == "parent":
            return self.get_params_parent(params, **kwargs)
        if which_param == "child":
            return self.get_params_child(params, θ, chunk, **kwargs)
        raise ValueError(
            "Expected which_param to be either "
            f"'parent' or 'child' but got {which_param}")

